package top.lingkang.mm.orm;

import cn.hutool.core.io.FileUtil;
import cn.hutool.core.lang.Assert;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSession;
import top.lingkang.mm.error.MagicException;
import top.lingkang.mm.utils.MagicUtils;

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.util.List;
import java.util.Locale;

/**
 * Default {@link MapperManage} implementation that runs entity CRUD operations through a MyBatis
 * {@link SqlSession}, addressing generated statements by the ids returned from {@code getSqlId}.
 *
 * <p>NOTE(review): this class adds no synchronization of its own; its thread-safety is whatever
 * the wrapped {@code SqlSession} provides.
 *
 * @author lingkang
 * @since 2024/3/1 10:24
 */
public class MapperManageImpl extends BaseMapperManage implements MapperManage {

    /**
     * Lazily computed database protocol: the second colon-separated segment of the lower-cased
     * connection URL (e.g. {@code mysql} for {@code jdbc:mysql://...}). See {@link #getProtocol()}.
     */
    private String protocol;

    /**
     * @param configuration MyBatis configuration, used by {@link #getMapper(Class)} to resolve mappers
     * @param sqlSession    session that executes every statement issued by this manager
     */
    public MapperManageImpl(Configuration configuration, SqlSession sqlSession) {
        super(configuration, sqlSession);
    }

    /**
     * Creates a query wrapper for a raw SQL string.
     *
     * @param sql         SQL text to execute
     * @param resultClass type each result row is mapped to
     */
    @Override
    public <T> QueryWrapper<T> createQuery(String sql, Class<T> resultClass) {
        return new QueryWrapper<>(sql, resultClass, this);
    }

    /**
     * Same as {@link #createQuery(String, Class)} with the arguments in the opposite order.
     */
    @Override
    public <T> QueryWrapper<T> createQuery(Class<T> resultClass, String sql) {
        return new QueryWrapper<>(sql, resultClass, this);
    }

    /**
     * Creates an update wrapper for a raw SQL string.
     *
     * @param sql SQL text to execute
     */
    @Override
    public UpdateWrapper createUpdate(String sql) {
        return new UpdateWrapper(sql, this);
    }

    /**
     * Runs the generated {@code selectAll} statement for the entity class.
     */
    @Override
    public <T> List<T> selectAll(Class<T> entityClass) {
        String sqlId = getSqlId(entityClass, "selectAll");
        return sqlSession.selectList(sqlId);
    }

    /**
     * Runs the generated {@code selectById} statement.
     *
     * @return the mapped entity, or {@code null} when MyBatis finds no row
     * @throws MagicException if the entity class has no {@code @Id} field
     */
    @Override
    public <T> T selectById(Class<T> entityClass, Object id) {
        MagicEntity entity = getMagicEntity(entityClass);
        if (entity.getIdIndex() == -1) {
            throw new MagicException("实体对象没有 @Id 注解: " + entityClass.getName());
        }
        String sqlId = getSqlId(entity, "selectById");
        return sqlSession.selectOne(sqlId, id);
    }

    /**
     * Runs the generated {@code existsById} statement.
     *
     * @return {@code true} only if the statement yields a non-null {@code true} value
     * @throws MagicException if the entity class has no {@code @Id} field
     */
    @Override
    public <T> boolean existsById(Class<T> entityClass, Object id) {
        MagicEntity entity = getMagicEntity(entityClass);
        if (entity.getIdIndex() == -1) {
            throw new MagicException("实体对象没有 @Id 注解: " + entityClass.getName());
        }
        String sqlId = getSqlId(entity, "existsById");
        // selectOne returns null when no row comes back; check before unboxing to avoid an NPE.
        Boolean exists = sqlSession.selectOne(sqlId, id);
        return exists != null && exists;
    }

    /**
     * Inserts the entity including null-valued columns; same as {@code insert(entity, true)}.
     */
    @Override
    public int insert(Object entity) {
        return insert(entity, true);
    }

    /**
     * Inserts a single entity.
     *
     * @param entity     entity to insert; must not be null
     * @param insertNull {@code true} to use the statement that writes every column, {@code false} to
     *                   use the "not null" statement with parameters from
     *                   {@code MagicEntityUtils.getInsertNotNullParams}
     * @return the row count reported by MyBatis
     * @throws IllegalArgumentException if {@code entity} is null
     */
    @Override
    public int insert(Object entity, boolean insertNull) {
        Assert.notNull(entity, "插入对象不能为空");
        MagicEntity magicEntity = getMagicEntity(entity.getClass());
        // Check the id and set it if required, using the generator from getIdGenerate().
        checkIdSet(magicEntity, entity, getIdGenerate());
        // NOTE(review): presumably fills annotated creation/update time fields — confirm in MagicEntityUtils.
        MagicEntityUtils.autoSetTime(magicEntity, entity, true);
        MagicEntityUtils.execPreUpdate(magicEntity, entity);
        boolean hasId = magicEntity.getIdAnn() != null;
        int result;
        if (insertNull) {
            result = sqlSession.insert(getSqlId(magicEntity, hasId ? "insert" : "insertNotId"), entity);
        } else {
            result = sqlSession.insert(
                    getSqlId(magicEntity, hasId ? "notInsertNull" : "notInsertNullNotId"),
                    MagicEntityUtils.getInsertNotNullParams(entity, magicEntity));
        }
        MagicEntityUtils.execPostUpdate(magicEntity, entity);
        return result;
    }

    /**
     * Inserts all entities with a single generated list-insert statement.
     *
     * <p>The entity metadata is taken from the first element's class. NOTE(review): this assumes
     * every element has that same class. Unlike {@link #insert(Object, boolean)}, ids are not
     * checked or generated here via {@code checkIdSet} — confirm this is intended.
     *
     * @return the row count reported by MyBatis
     * @throws MagicException if {@code list} is null or empty, or contains a null element
     */
    @Override
    public <T> int insertBatch(List<T> list) {
        if (list == null || list.isEmpty()) {
            throw new MagicException("插入对象列表不能空");
        }
        for (int i = 0; i < list.size(); i++) {
            if (list.get(i) == null) {
                throw new MagicException("插入对象列表存在空值，下标: " + i);
            }
        }
        MagicEntity magicEntity = getMagicEntity(list.get(0).getClass());
        MagicEntityUtils.autoSetTimeList(magicEntity, list, true);
        MagicEntityUtils.execPreUpdateList(magicEntity, list);
        String sqlId = getSqlId(magicEntity, magicEntity.getIdAnn() != null ? "insertList" : "insertListNotId");
        int result = sqlSession.insert(sqlId, list);
        MagicEntityUtils.execPostUpdateList(magicEntity, list);
        return result;
    }

    /**
     * Updates every column of the entity by id; same as {@code updateById(entity, true)}.
     */
    @Override
    public int updateById(Object entity) {
        return updateById(entity, true);
    }

    /**
     * Updates an entity by its {@code @Id} field.
     *
     * @param entity     entity instance (not a {@link Class}); must not be null
     * @param updateNull {@code true} to use {@code updateById}, {@code false} to use
     *                   {@code updateNotNullById} with parameters from {@code getUpdateNotNullParams}
     * @return the row count reported by MyBatis
     * @throws IllegalArgumentException if {@code entity} is null or is a {@link Class}
     * @throws MagicException           if the entity class has no {@code @Id} annotation
     */
    @Override
    public int updateById(Object entity, boolean updateNull) {
        Assert.notNull(entity, "更新对象不能为空");
        // Template form: the entity's toString() only runs when the assertion actually fails.
        Assert.isFalse(entity instanceof Class, "更新的数据类型错误: {}", entity);
        MagicEntity magicEntity = getMagicEntity(entity.getClass());
        if (magicEntity.getIdAnn() == null) {
            throw new MagicException("更新实体类没有 @Id 注解: " + magicEntity.getClazz().getName());
        }
        MagicEntityUtils.autoSetTime(magicEntity, entity, false);
        MagicEntityUtils.execPreUpdate(magicEntity, entity);
        int result;
        if (updateNull) {
            result = sqlSession.update(getSqlId(magicEntity, "updateById"), entity);
        } else {
            result = sqlSession.update(getSqlId(magicEntity, "updateNotNullById"),
                    getUpdateNotNullParams(entity, magicEntity));
        }
        MagicEntityUtils.execPostUpdate(magicEntity, entity);
        return result;
    }

    /**
     * Deletes the row matching the entity's {@code @Id} value.
     *
     * @param entity entity instance (not a {@link Class}); must not be null
     * @return the row count reported by MyBatis
     * @throws IllegalArgumentException if {@code entity} is null or is a {@link Class}
     * @throws MagicException           if there is no {@code @Id} annotation or the id value is null
     */
    @Override
    public int deleteById(Object entity) {
        Assert.notNull(entity, "删除对象不能为空");
        Assert.isFalse(entity instanceof Class, "删除的数据类型错误: {}", entity);
        MagicEntity magicEntity = getMagicEntity(entity.getClass());
        if (magicEntity.getIdAnn() == null) {
            throw new MagicException("删除实体类没有 @Id 注解: " + magicEntity.getClazz().getName());
        }
        Object idValue = getIdValue(entity, magicEntity.getClazz(),
                magicEntity.getFields().get(magicEntity.getIdIndex()).getName());
        if (idValue == null) {
            throw new MagicException("删除实体对象的id不能为空: " + entity);
        }
        MagicEntityUtils.execPreUpdate(magicEntity, entity);
        int result = sqlSession.delete(getSqlId(magicEntity, "deleteById"), idValue);
        MagicEntityUtils.execPostUpdate(magicEntity, entity);
        return result;
    }

    /**
     * Deletes a row by id. Unlike {@link #deleteById(Object)}, this overload does not call the
     * pre/post update hooks, because it has no entity instance.
     *
     * @return the row count reported by MyBatis
     * @throws IllegalArgumentException if {@code entityClass} is null
     * @throws MagicException           if {@code id} is null or there is no {@code @Id} annotation
     */
    @Override
    public int deleteById(Class<?> entityClass, Object id) {
        Assert.notNull(entityClass, "实体类不能为空");
        if (id == null) {
            throw new MagicException("实体对象的id不能为空: " + entityClass.getName());
        }
        MagicEntity magicEntity = getMagicEntity(entityClass);
        if (magicEntity.getIdAnn() == null) {
            throw new MagicException("删除实体类没有 @Id 注解: " + magicEntity.getClazz().getName());
        }
        return sqlSession.delete(getSqlId(magicEntity, "deleteById"), id);
    }

    /**
     * Deletes all rows whose id is in {@code ids}, using the generated {@code deleteByIds} statement.
     *
     * @return the row count reported by MyBatis
     * @throws IllegalArgumentException if {@code entityClass} is null
     * @throws MagicException           if {@code ids} is null or empty, contains null, or there is
     *                                  no {@code @Id} annotation
     */
    @Override
    public <T, E> int deleteByIds(Class<T> entityClass, List<E> ids) {
        Assert.notNull(entityClass, "实体类不能为空");
        if (ids == null || ids.isEmpty()) {
            throw new MagicException("实体对象的ids集合不能为空: " + entityClass.getName());
        }
        // Explicit loop instead of ids.contains(null): immutable lists (List.of) throw NPE on contains(null).
        for (E id : ids) {
            if (id == null) {
                throw new MagicException("实体对象的ids集合存在空值：" + ids);
            }
        }
        MagicEntity magicEntity = getMagicEntity(entityClass);
        if (magicEntity.getIdAnn() == null) {
            throw new MagicException("删除实体类没有 @Id 注解: " + magicEntity.getClazz().getName());
        }
        return sqlSession.delete(getSqlId(magicEntity, "deleteByIds"), ids);
    }

    /**
     * Returns a MyBatis mapper proxy bound to this manager's session.
     */
    @Override
    public <T> T getMapper(Class<T> type) {
        return configuration.getMapper(type, sqlSession);
    }

    /**
     * Returns the select SQL held by the entity's {@code MagicEntity} metadata.
     */
    @Override
    public <T> String selectTableSql(Class<T> entityClass) {
        return getMagicEntity(entityClass).getSelectTableSql();
    }

    /**
     * Returns the table name held by the entity's {@code MagicEntity} metadata.
     */
    @Override
    public String getTableName(Class<?> entityClass) {
        return getMagicEntity(entityClass).getTableName();
    }

    /**
     * Executes an SQL script on the session's connection via {@code MagicUtils.exeScript}.
     * This method does not close the connection.
     */
    @Override
    public void executeSqlScript(String sqlScript) {
        Connection connection = sqlSession.getConnection();
        MagicUtils.exeScript(sqlScript, connection);
    }

    /**
     * Reads {@code scriptFile} as UTF-8 and executes it with {@link #executeSqlScript(String)}.
     *
     * @throws MagicException if the file is null, does not exist, or is not a regular file
     */
    @Override
    public void executeSqlScript(File scriptFile) {
        if (scriptFile == null) {
            throw new MagicException("sql脚本文件不存在");
        }
        if (!scriptFile.exists()) {
            throw new MagicException("sql脚本文件不存在：" + scriptFile.getAbsolutePath());
        }
        if (!scriptFile.isFile()) {
            throw new MagicException("sql脚本路径不是文件：" + scriptFile.getAbsolutePath());
        }
        executeSqlScript(FileUtil.readString(scriptFile, StandardCharsets.UTF_8));
    }

    /**
     * Returns the JDBC connection of the underlying session. The caller should not close it; the
     * session owns it.
     */
    @Override
    public Connection getConnection() {
        return sqlSession.getConnection();
    }

    /**
     * Returns the database URL reported by {@code MagicUtils.getDatabaseURL(connection, true)}.
     */
    @Override
    public String getUrl() {
        return MagicUtils.getDatabaseURL(getConnection(), true);
    }

    /**
     * Returns the second colon-separated segment of the lower-cased database URL, e.g. {@code mysql}
     * for {@code jdbc:mysql://host/db}. The value is cached after the first successful call.
     *
     * <p>Not synchronized: concurrent first calls may each compute the value, which is harmless as
     * long as {@link #getUrl()} is stable.
     *
     * @throws MagicException if the URL is null or has no protocol segment
     */
    @Override
    public String getProtocol() {
        if (protocol == null) {
            String url = getUrl();
            if (url == null) {
                throw new MagicException("无法获取数据库连接url");
            }
            // Locale.ROOT keeps lower-casing locale-independent (e.g. Turkish dotless 'i').
            String[] segments = url.toLowerCase(Locale.ROOT).split(":");
            if (segments.length < 2) {
                // The URL itself is deliberately left out of the message: it may carry credentials.
                throw new MagicException("无法从数据库连接url中解析协议");
            }
            protocol = segments[1];
        }
        return protocol;
    }
}
